{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Read before you start:\n",
    "\n",
    "This notebook gives a test demo for all the LLMs we trained during phase2: Multi-Task Instruction Tuning.\n",
    "\n",
    "- LLMs: Llama2, Falcon, BLOOM, ChatGLM2, Qwen, MPT\n",
    "- Tasks: Sentiment Analysis, Headline Classification, Named Entity Extraction, Relation Extraction\n",
    "\n",
    "All the models & instruction data samples used are also available in our huggingface organization. [https://huggingface.co/FinGPT]\n",
    "\n",
    "Models trained in phase1&3 are not provided, as MT-models cover most of their capacity. Feel free to train your own models based on the tasks you want.\n",
    "\n",
    "Due to the limited diversity of the financial tasks and datasets we used, models might not response correctly to out-of-scope instructions. We'll delve into the generalization ability more in our future works."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First choose to load data/model from huggingface or local space\n",
    "\n",
    "FROM_REMOTE = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2023-10-15 20:44:54,084] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "from peft import PeftModel\n",
    "from utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "# Suppress Warnings during inference\n",
    "logging.getLogger(\"transformers\").setLevel(logging.ERROR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "demo_tasks = [\n",
    "    'Financial Sentiment Analysis',\n",
    "    'Financial Relation Extraction',\n",
    "    'Financial Headline Classification',\n",
    "    'Financial Named Entity Recognition',\n",
    "]\n",
    "demo_inputs = [\n",
    "    \"Glaxo's ViiV Healthcare Signs China Manufacturing Deal With Desano\",\n",
    "    \"Wednesday, July 8, 2015 10:30AM IST (5:00AM GMT) Rimini Street Comment on Oracle Litigation Las Vegas, United States Rimini Street, Inc., the leading independent provider of enterprise software support for SAP AG’s (NYSE:SAP) Business Suite and BusinessObjects software and Oracle Corporation’s (NYSE:ORCL) Siebel , PeopleSoft , JD Edwards , E-Business Suite , Oracle Database , Hyperion and Oracle Retail software, today issued a statement on the Oracle litigation.\",\n",
    "    'april gold down 20 cents to settle at $1,116.10/oz',\n",
    "    'Subject to the terms and conditions of this Agreement , Bank agrees to lend to Borrower , from time to time prior to the Commitment Termination Date , equipment advances ( each an \" Equipment Advance \" and collectively the \" Equipment Advances \").',\n",
    "]\n",
    "demo_instructions = [\n",
    "    'What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.',\n",
    "    'Given phrases that describe the relationship between two words/phrases as options, extract the word/phrase pair and the corresponding lexical relationship between them from the input text. The output format should be \"relation1: word1, word2; relation2: word3, word4\". Options: product/material produced, manufacturer, distributed by, industry, position held, original broadcaster, owned by, founded by, distribution format, headquarters location, stock exchange, currency, parent organization, chief executive officer, director/manager, owner of, operator, member of, employer, chairperson, platform, subsidiary, legal form, publisher, developer, brand, business division, location of formation, creator.',\n",
    "    'Does the news headline talk about price in the past? Please choose an answer from {Yes/No}.',\n",
    "    'Please extract entities and their types from the input sentence, entity types should be chosen from {person/organization/location}.',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_model(base_model, peft_model, from_remote=False):\n",
    "    \n",
    "    model_name = parse_model_name(base_model, from_remote)\n",
    "\n",
    "    model = AutoModelForCausalLM.from_pretrained(\n",
    "        model_name, trust_remote_code=True, \n",
    "        device_map=\"auto\",\n",
    "    )\n",
    "    model.model_parallel = True\n",
    "\n",
    "    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
    "    \n",
    "    tokenizer.padding_side = \"left\"\n",
    "    if base_model == 'qwen':\n",
    "        tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')\n",
    "        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('<|extra_0|>')\n",
    "    if not tokenizer.pad_token or tokenizer.pad_token_id == tokenizer.eos_token_id:\n",
    "        tokenizer.add_special_tokens({'pad_token': '[PAD]'})\n",
    "        model.resize_token_embeddings(len(tokenizer))\n",
    "    \n",
    "    model = PeftModel.from_pretrained(model, peft_model)\n",
    "    model = model.eval()\n",
    "    return model, tokenizer\n",
    "\n",
    "\n",
    "def test_demo(model, tokenizer):\n",
    "\n",
    "    for task_name, input, instruction in zip(demo_tasks, demo_inputs, demo_instructions):\n",
    "        prompt = 'Instruction: {instruction}\\nInput: {input}\\nAnswer: '.format(\n",
    "            input=input, \n",
    "            instruction=instruction\n",
    "        )\n",
    "        inputs = tokenizer(\n",
    "            prompt, return_tensors='pt',\n",
    "            padding=True, max_length=512,\n",
    "            return_token_type_ids=False\n",
    "        )\n",
    "        inputs = {key: value.to(model.device) for key, value in inputs.items()}\n",
    "        res = model.generate(\n",
    "            **inputs, max_length=512, do_sample=False,\n",
    "            eos_token_id=tokenizer.eos_token_id\n",
    "        )\n",
    "        output = tokenizer.decode(res[0], skip_special_tokens=True)\n",
    "        print(f\"\\n==== {task_name} ====\\n\")\n",
    "        print(output)\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Llama2-7B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.006228446960449219,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Loading checkpoint shards",
       "rate": null,
       "total": 2,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0d58aff745fb486780792c86384fe702",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using pad_token, but it is not set yet.\n",
      "/root/.conda/envs/torch2/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:2436: UserWarning: `max_length` is ignored when `padding`=`True` and there is no truncation strategy. To pad to max length, use `padding='max_length'`.\n",
      "  warnings.warn(\n",
      "/root/.conda/envs/torch2/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:362: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
      "  warnings.warn(\n",
      "/root/.conda/envs/torch2/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:367: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "==== Financial Sentiment Analysis ====\n",
      "\n",
      "Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Glaxo's ViiV Healthcare Signs China Manufacturing Deal With Desano\n",
      "Answer:  positive\n",
      "\n",
      "==== Financial Relation Extraction ====\n",
      "\n",
      "Instruction: Given phrases that describe the relationship between two words/phrases as options, extract the word/phrase pair and the corresponding lexical relationship between them from the input text. The output format should be \"relation1: word1, word2; relation2: word3, word4\". Options: product/material produced, manufacturer, distributed by, industry, position held, original broadcaster, owned by, founded by, distribution format, headquarters location, stock exchange, currency, parent organization, chief executive officer, director/manager, owner of, operator, member of, employer, chairperson, platform, subsidiary, legal form, publisher, developer, brand, business division, location of formation, creator.\n",
      "Input: Wednesday, July 8, 2015 10:30AM IST (5:00AM GMT) Rimini Street Comment on Oracle Litigation Las Vegas, United States Rimini Street, Inc., the leading independent provider of enterprise software support for SAP AG’s (NYSE:SAP) Business Suite and BusinessObjects software and Oracle Corporation’s (NYSE:ORCL) Siebel , PeopleSoft , JD Edwards , E-Business Suite , Oracle Database , Hyperion and Oracle Retail software, today issued a statement on the Oracle litigation.\n",
      "Answer:  product_or_material_produced: PeopleSoft, software; parent_organization: Siebel, Oracle Corporation; industry: Oracle Corporation, software; product_or_material_produced: Oracle Corporation, software; product_or_material_produced: Oracle Corporation, software\n",
      "\n",
      "==== Financial Headline Classification ====\n",
      "\n",
      "Instruction: Does the news headline talk about price in the past? Please choose an answer from {Yes/No}.\n",
      "Input: april gold down 20 cents to settle at $1,116.10/oz\n",
      "Answer:  Yes\n",
      "\n",
      "==== Financial Named Entity Recognition ====\n",
      "\n",
      "Instruction: Please extract entities and their types from the input sentence, entity types should be chosen from {person/organization/location}.\n",
      "Input: Subject to the terms and conditions of this Agreement , Bank agrees to lend to Borrower , from time to time prior to the Commitment Termination Date , equipment advances ( each an \" Equipment Advance \" and collectively the \" Equipment Advances \").\n",
      "Answer:  Bank is an organization, Borrower is a person.\n"
     ]
    }
   ],
   "source": [
    "base_model = 'llama2'\n",
    "peft_model = 'FinGPT/fingpt-mt_llama2-7b_lora' if FROM_REMOTE else 'finetuned_models/MT-llama2-linear_202309241345'\n",
    "\n",
    "model, tokenizer = load_model(base_model, peft_model, FROM_REMOTE)\n",
    "test_demo(model, tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Qwen-7B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The model is automatically converting to bf16 for faster inference. If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\".\n",
      "Try importing flash-attention for faster inference...\n",
      "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary\n",
      "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm\n",
      "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency https://github.com/Dao-AILab/flash-attention\n"
     ]
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004647493362426758,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Loading checkpoint shards",
       "rate": null,
       "total": 8,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e1978e69ea784778acd1813cc0647c3e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/.conda/envs/torch2/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:367: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.8` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n",
      "  warnings.warn(\n",
      "/root/.conda/envs/torch2/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:377: UserWarning: `do_sample` is set to `False`. However, `top_k` is set to `0` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_k`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "==== Financial Sentiment Analysis ====\n",
      "\n",
      "Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Glaxo's ViiV Healthcare Signs China Manufacturing Deal With Desano\n",
      "Answer: positive\n",
      "\n",
      "==== Financial Relation Extraction ====\n",
      "\n",
      "Instruction: Given phrases that describe the relationship between two words/phrases as options, extract the word/phrase pair and the corresponding lexical relationship between them from the input text. The output format should be \"relation1: word1, word2; relation2: word3, word4\". Options: product/material produced, manufacturer, distributed by, industry, position held, original broadcaster, owned by, founded by, distribution format, headquarters location, stock exchange, currency, parent organization, chief executive officer, director/manager, owner of, operator, member of, employer, chairperson, platform, subsidiary, legal form, publisher, developer, brand, business division, location of formation, creator.\n",
      "Input: Wednesday, July 8, 2015 10:30AM IST (5:00AM GMT) Rimini Street Comment on Oracle Litigation Las Vegas, United States Rimini Street, Inc., the leading independent provider of enterprise software support for SAP AG’s (NYSE:SAP) Business Suite and BusinessObjects software and Oracle Corporation’s (NYSE:ORCL) Siebel , PeopleSoft , JD Edwards , E-Business Suite , Oracle Database , Hyperion and Oracle Retail software, today issued a statement on the Oracle litigation.\n",
      "Answer: subsidiary: PeopleSoft, JD Edwards\n",
      "\n",
      "==== Financial Headline Classification ====\n",
      "\n",
      "Instruction: Does the news headline talk about price in the past? Please choose an answer from {Yes/No}.\n",
      "Input: april gold down 20 cents to settle at $1,116.10/oz\n",
      "Answer: Yes\n",
      "\n",
      "==== Financial Named Entity Recognition ====\n",
      "\n",
      "Instruction: Please extract entities and their types from the input sentence, entity types should be chosen from {person/organization/location}.\n",
      "Input: Subject to the terms and conditions of this Agreement , Bank agrees to lend to Borrower , from time to time prior to the Commitment Termination Date , equipment advances ( each an \" Equipment Advance \" and collectively the \" Equipment Advances \").\n",
      "Answer: Bank is an organization, Borrower is a person.\n"
     ]
    }
   ],
   "source": [
    "base_model = 'qwen'\n",
    "peft_model = 'FinGPT/fingpt-mt_qwen-7b_lora' if FROM_REMOTE else 'finetuned_models/MT-qwen-linear_202309221011'\n",
    "\n",
    "model, tokenizer = load_model(base_model, peft_model, FROM_REMOTE)\n",
    "test_demo(model, tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Falcon-7B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004422426223754883,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Loading checkpoint shards",
       "rate": null,
       "total": 2,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e12fadfbaa6048538bbeef26ed563b28",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using pad_token, but it is not set yet.\n",
      "/root/.conda/envs/torch2/lib/python3.9/site-packages/transformers/generation/utils.py:1411: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation )\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "==== Financial Sentiment Analysis ====\n",
      "\n",
      "Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Glaxo's ViiV Healthcare Signs China Manufacturing Deal With Desano\n",
      "Answer: positive\n",
      "\n",
      "==== Financial Relation Extraction ====\n",
      "\n",
      "Instruction: Given phrases that describe the relationship between two words/phrases as options, extract the word/phrase pair and the corresponding lexical relationship between them from the input text. The output format should be \"relation1: word1, word2; relation2: word3, word4\". Options: product/material produced, manufacturer, distributed by, industry, position held, original broadcaster, owned by, founded by, distribution format, headquarters location, stock exchange, currency, parent organization, chief executive officer, director/manager, owner of, operator, member of, employer, chairperson, platform, subsidiary, legal form, publisher, developer, brand, business division, location of formation, creator.\n",
      "Input: Wednesday, July 8, 2015 10:30AM IST (5:00AM GMT) Rimini Street Comment on Oracle Litigation Las Vegas, United States Rimini Street, Inc., the leading independent provider of enterprise software support for SAP AG’s (NYSE:SAP) Business Suite and BusinessObjects software and Oracle Corporation’s (NYSE:ORCL) Siebel, PeopleSoft, JD Edwards, E-Business Suite, Oracle Database, Hyperion and Oracle Retail software, today issued a statement on the Oracle litigation.\n",
      "Answer: product_or_material_produced: PeopleSoft, Oracle Database\n",
      "\n",
      "==== Financial Headline Classification ====\n",
      "\n",
      "Instruction: Does the news headline talk about price in the past? Please choose an answer from {Yes/No}.\n",
      "Input: april gold down 20 cents to settle at $1,116.10/oz\n",
      "Answer: Yes\n",
      "\n",
      "==== Financial Named Entity Recognition ====\n",
      "\n",
      "Instruction: Please extract entities and their types from the input sentence, entity types should be chosen from {person/organization/location}.\n",
      "Input: Subject to the terms and conditions of this Agreement, Bank agrees to lend to Borrower, from time to time prior to the Commitment Termination Date, equipment advances ( each an \" Equipment Advance \" and collectively the \" Equipment Advances \").\n",
      "Answer: Bank is an organization, Borrower is a person.\n"
     ]
    }
   ],
   "source": [
    "base_model = 'falcon'\n",
    "peft_model = 'FinGPT/fingpt-mt_falcon-7b_lora' if FROM_REMOTE else 'finetuned_models/MT-falcon-linear_202309210126'\n",
    "\n",
    "model, tokenizer = load_model(base_model, peft_model, FROM_REMOTE)\n",
    "test_demo(model, tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ChatGLM2-6B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004460573196411133,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Loading checkpoint shards",
       "rate": null,
       "total": 7,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8bddd025a6514946b5f07f55e9c38f58",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "==== Financial Sentiment Analysis ====\n",
      "\n",
      "Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Glaxo's ViiV Healthcare Signs China Manufacturing Deal With Desano\n",
      "Answer:  positive\n",
      "\n",
      "==== Financial Relation Extraction ====\n",
      "\n",
      "Instruction: Given phrases that describe the relationship between two words/phrases as options, extract the word/phrase pair and the corresponding lexical relationship between them from the input text. The output format should be \"relation1: word1, word2; relation2: word3, word4\". Options: product/material produced, manufacturer, distributed by, industry, position held, original broadcaster, owned by, founded by, distribution format, headquarters location, stock exchange, currency, parent organization, chief executive officer, director/manager, owner of, operator, member of, employer, chairperson, platform, subsidiary, legal form, publisher, developer, brand, business division, location of formation, creator.\n",
      "Input: Wednesday, July 8, 2015 10:30AM IST (5:00AM GMT) Rimini Street Comment on Oracle Litigation Las Vegas, United States Rimini Street, Inc., the leading independent provider of enterprise software support for SAP AG’s (NYSE:SAP) Business Suite and BusinessObjects software and Oracle Corporation’s (NYSE:ORCL) Siebel , PeopleSoft , JD Edwards , E-Business Suite , Oracle Database , Hyperion and Oracle Retail software, today issued a statement on the Oracle litigation.\n",
      "Answer:  product_or_material_produced: Oracle, Oracle Database; developer: Oracle, Oracle; product_or_material_produced: Oracle, Oracle Database\n",
      "\n",
      "==== Financial Headline Classification ====\n",
      "\n",
      "Instruction: Does the news headline talk about price in the past? Please choose an answer from {Yes/No}.\n",
      "Input: april gold down 20 cents to settle at $1,116.10/oz\n",
      "Answer:  Yes\n",
      "\n",
      "==== Financial Named Entity Recognition ====\n",
      "\n",
      "Instruction: Please extract entities and their types from the input sentence, entity types should be chosen from {person/organization/location}.\n",
      "Input: Subject to the terms and conditions of this Agreement , Bank agrees to lend to Borrower , from time to time prior to the Commitment Termination Date , equipment advances ( each an \" Equipment Advance \" and collectively the \" Equipment Advances \").\n",
      "Answer:  Bank is an organization, Borrower is a person.\n"
     ]
    }
   ],
   "source": [
    "base_model = 'chatglm2'\n",
    "peft_model = 'FinGPT/fingpt-mt_chatglm2-6b_lora' if FROM_REMOTE else 'finetuned_models/MT-chatglm2-linear_202309201120'\n",
    "\n",
    "model, tokenizer = load_model(base_model, peft_model, FROM_REMOTE)\n",
    "test_demo(model, tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BLOOM-7B1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004486799240112305,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Loading checkpoint shards",
       "rate": null,
       "total": 2,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "32ee0b5e2df049a0b9e458c779e09a68",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "==== Financial Sentiment Analysis ====\n",
      "\n",
      "Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Glaxo's ViiV Healthcare Signs China Manufacturing Deal With Desano\n",
      "Answer: positive\n",
      "\n",
      "==== Financial Relation Extraction ====\n",
      "\n",
      "Instruction: Given phrases that describe the relationship between two words/phrases as options, extract the word/phrase pair and the corresponding lexical relationship between them from the input text. The output format should be \"relation1: word1, word2; relation2: word3, word4\". Options: product/material produced, manufacturer, distributed by, industry, position held, original broadcaster, owned by, founded by, distribution format, headquarters location, stock exchange, currency, parent organization, chief executive officer, director/manager, owner of, operator, member of, employer, chairperson, platform, subsidiary, legal form, publisher, developer, brand, business division, location of formation, creator.\n",
      "Input: Wednesday, July 8, 2015 10:30AM IST (5:00AM GMT) Rimini Street Comment on Oracle Litigation Las Vegas, United States Rimini Street, Inc., the leading independent provider of enterprise software support for SAP AG’s (NYSE:SAP) Business Suite and BusinessObjects software and Oracle Corporation’s (NYSE:ORCL) Siebel , PeopleSoft , JD Edwards , E-Business Suite , Oracle Database , Hyperion and Oracle Retail software, today issued a statement on the Oracle litigation.\n",
      "Answer: product_or_material_produced: software provider, Software\n",
      "\n",
      "==== Financial Headline Classification ====\n",
      "\n",
      "Instruction: Does the news headline talk about price in the past? Please choose an answer from {Yes/No}.\n",
      "Input: april gold down 20 cents to settle at $1,116.10/oz\n",
      "Answer: Yes\n",
      "\n",
      "==== Financial Named Entity Recognition ====\n",
      "\n",
      "Instruction: Please extract entities and their types from the input sentence, entity types should be chosen from {person/organization/location}.\n",
      "Input: Subject to the terms and conditions of this Agreement , Bank agrees to lend to Borrower , from time to time prior to the Commitment Termination Date , equipment advances ( each an \" Equipment Advance \" and collectively the \" Equipment Advances \").\n",
      "Answer: Bank is an organization, Borrower is a person.\n"
     ]
    }
   ],
   "source": [
    "base_model = 'bloom'\n",
    "peft_model = 'FinGPT/fingpt-mt_bloom-7b1_lora' if FROM_REMOTE else 'finetuned_models/MT-bloom-linear_202309211510'\n",
    "\n",
    "model, tokenizer = load_model(base_model, peft_model, FROM_REMOTE)\n",
    "test_demo(model, tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MPT-7B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/.cache/huggingface/modules/transformers_modules/mpt-7b-peft-compatible/attention.py:148: UserWarning: Using `attn_impl: torch`. If your model does not use `alibi` or `prefix_lm` we recommend using `attn_impl: flash` otherwise we recommend using `attn_impl: triton`.\n",
      "  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')\n",
      "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n"
     ]
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004449605941772461,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Loading checkpoint shards",
       "rate": null,
       "total": 2,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0440bc96112344c493c8a1f5dd76f319",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using pad_token, but it is not set yet.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "==== Financial Sentiment Analysis ====\n",
      "\n",
      "Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Glaxo's ViiV Healthcare Signs China Manufacturing Deal With Desano\n",
      "Answer: positive\n",
      "\n",
      "==== Financial Relation Extraction ====\n",
      "\n",
      "Instruction: Given phrases that describe the relationship between two words/phrases as options, extract the word/phrase pair and the corresponding lexical relationship between them from the input text. The output format should be \"relation1: word1, word2; relation2: word3, word4\". Options: product/material produced, manufacturer, distributed by, industry, position held, original broadcaster, owned by, founded by, distribution format, headquarters location, stock exchange, currency, parent organization, chief executive officer, director/manager, owner of, operator, member of, employer, chairperson, platform, subsidiary, legal form, publisher, developer, brand, business division, location of formation, creator.\n",
      "Input: Wednesday, July 8, 2015 10:30AM IST (5:00AM GMT) Rimini Street Comment on Oracle Litigation Las Vegas, United States Rimini Street, Inc., the leading independent provider of enterprise software support for SAP AG’s (NYSE:SAP) Business Suite and BusinessObjects software and Oracle Corporation’s (NYSE:ORCL) Siebel, PeopleSoft, JD Edwards, E-Business Suite, Oracle Database, Hyperion and Oracle Retail software, today issued a statement on the Oracle litigation.\n",
      "Answer: product_or_material_produced: Hyperion, software\n",
      "\n",
      "==== Financial Headline Classification ====\n",
      "\n",
      "Instruction: Does the news headline talk about price in the past? Please choose an answer from {Yes/No}.\n",
      "Input: april gold down 20 cents to settle at $1,116.10/oz\n",
      "Answer: Yes\n",
      "\n",
      "==== Financial Named Entity Recognition ====\n",
      "\n",
      "Instruction: Please extract entities and their types from the input sentence, entity types should be chosen from {person/organization/location}.\n",
      "Input: Subject to the terms and conditions of this Agreement, Bank agrees to lend to Borrower, from time to time prior to the Commitment Termination Date, equipment advances ( each an \" Equipment Advance \" and collectively the \" Equipment Advances \").\n",
      "Answer: Bank is an organization, Borrower is a person.\n"
     ]
    }
   ],
   "source": [
    "base_model = 'mpt'\n",
    "peft_model = 'FinGPT/fingpt-mt_mpt-7b_lora' if FROM_REMOTE else 'finetuned_models/MT-mpt-linear_202309230221'\n",
    "\n",
    "model, tokenizer = load_model(base_model, peft_model, FROM_REMOTE)\n",
    "test_demo(model, tokenizer)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch2",
   "language": "python",
   "name": "torch2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
